import numpy as np

"""
basic dropout
"""


class Dropout:
    """Classic (non-inverted) dropout layer with a manual backward pass.

    In training mode each element of the input is zeroed independently with
    probability ``drop_ratio``; surviving elements pass through unscaled.
    Because training does not rescale, inference mode (``train=False``)
    multiplies the input by the keep probability ``1 - drop_ratio`` so the
    expected activation matches what downstream layers saw during training.

    Args:
        drop_ratio: Probability of dropping each element, in ``[0, 1]``.
        rng: Optional ``numpy.random.Generator`` used to draw the mask. When
            omitted, the legacy global ``np.random`` state is used, so
            ``np.random.seed`` still controls the mask.

    Attributes:
        drop_ratio: The drop probability given at construction.
        rng: The generator passed at construction, or ``None``.
        flag: Boolean keep-mask from the most recent training-mode call
            (``True`` where the element was kept), or ``None`` before the
            first such call. ``backward`` reuses it.

    Raises:
        ValueError: If ``drop_ratio`` is outside ``[0, 1]`` (or NaN).
    """

    def __init__(self, drop_ratio=0.5, rng=None):
        # Written as a negated chained comparison so NaN is rejected too.
        if not 0.0 <= drop_ratio <= 1.0:
            raise ValueError(f"drop_ratio must be in [0, 1], got {drop_ratio!r}")
        self.drop_ratio = drop_ratio
        self.rng = rng
        self.flag = None

    def __call__(self, x, train=True):
        """Apply dropout to ``x``.

        Args:
            x: Input array (anything ``np.shape`` understands).
            train: If True (the default, matching the original behavior),
                sample a fresh keep-mask, store it in ``self.flag`` and zero
                the dropped elements. If False, return ``x`` scaled by the
                keep probability and leave ``self.flag`` untouched.

        Returns:
            Array with the same shape as ``x``.
        """
        if not train:
            # Non-inverted dropout: compensate at inference instead of training.
            return x * (1.0 - self.drop_ratio)

        shape = np.shape(x)
        if self.rng is None:
            # Keep the legacy global-state draw so np.random.seed still applies.
            noise = np.random.rand(*shape)
        else:
            noise = self.rng.random(shape)
        self.flag = noise > self.drop_ratio
        return x * self.flag

    def backward(self, dout):
        """Propagate the upstream gradient through the last training mask.

        Args:
            dout: Upstream gradient, same shape as the last training input.

        Returns:
            ``dout`` with the positions dropped in the forward pass zeroed.

        Raises:
            RuntimeError: If no training-mode forward pass has run yet.
        """
        if self.flag is None:
            raise RuntimeError("backward() called before a training-mode forward pass")
        return dout * self.flag
